#%%
import torch

class Flatten(torch.nn.Module):
    """Collapse every dimension after the first into a single dimension."""

    def forward(self, input):
        """Return ``input`` reshaped to ``(input.size(0), features)``.

        Args:
            input: Tensor with at least one dimension; dimension 0 is kept
                as-is and all remaining dimensions are merged.

        Returns:
            A 2-D tensor whose second dimension is the product of the sizes
            of ``input``'s trailing dimensions (1 for a 1-D input).
        """
        batch_size = input.size(0)
        # Work out the flattened width explicitly instead of passing -1 to
        # ``view``: with an empty batch (batch_size == 0) the -1 is ambiguous
        # and ``view`` raises a RuntimeError.
        features = 1
        for extent in input.shape[1:]:
            features *= extent
        # ``contiguous()`` guarantees ``view`` is legal for strided inputs.
        return input.contiguous().view(batch_size, features)
    

class AlexNet(torch.nn.Module):
    """Compact AlexNet-style CNN classifier with batch normalisation.

    The network is five 3x3 convolutions, each followed by ``BatchNorm2d``,
    interleaved with three 2x2 max-pools, then a two-layer linear head.
    The first convolution uses stride 2 and every pool halves the spatial
    size, so the hard-coded ``256 * 2 * 2`` width of the first linear layer
    fits 3-channel inputs that end up as a 2x2 feature map, e.g. 32x32
    images (32 -> 16 -> 8 -> 4 -> 2).

    Args:
        output_dim: Number of units produced by the final linear layer.
        device: Device the parameters and buffers are moved to at the end of
            construction; also stored as ``self.device``. Pass ``None`` to
            leave the module on the device it was created on.
        bn_momentum: ``momentum`` given to every ``BatchNorm2d`` layer.
            NOTE(review): in PyTorch this is the weight of the *newest* batch
            statistic (``running = (1 - m) * running + m * batch``), so the
            default of 0.9 makes the running estimates follow the latest
            batch closely. The default is kept for backward compatibility;
            confirm it was not meant as a Keras-style decay (PyTorch 0.1).
    """

    def __init__(self, output_dim, device='cpu', bn_momentum=0.9):
        super().__init__()
        self.device = device
        nn = torch.nn

        def batch_norm(channels):
            # Single place that applies the shared BatchNorm configuration.
            return nn.BatchNorm2d(channels, momentum=bn_momentum)

        # The order of the entries is significant: it fixes the integer
        # keys used in ``state_dict`` (``layer.0.weight``, ...).
        self.layer = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3,
                      stride=2, padding=1, bias=False),
            batch_norm(64),
            nn.MaxPool2d(2),  # kernel_size
            nn.ReLU(),
            nn.Conv2d(64, 192, 3, padding=1, bias=False),
            batch_norm(192),
            nn.MaxPool2d(2),
            nn.ReLU(),
            nn.Conv2d(192, 384, 3, padding=1, bias=False),
            batch_norm(384),
            nn.ReLU(),
            # The only convolution that keeps its bias term; retained so
            # that existing checkpoints (``layer.11.bias``) still load.
            nn.Conv2d(384, 256, 3, padding=1),
            batch_norm(256),
            nn.ReLU(),
            nn.Conv2d(256, 256, 3, padding=1, bias=False),
            batch_norm(256),
            nn.MaxPool2d(2),
            nn.ReLU(),
            # Built-in flatten: merges every dimension after the batch one.
            nn.Flatten(),
            nn.Linear(256 * 2 * 2, 256),
            nn.ReLU(),
            # torch.nn.Dropout(0.2),
            nn.Linear(256, output_dim),
        )

        # Actually honour ``device``: previously it was only stored, so the
        # weights stayed on the CPU regardless of the value passed in.
        if device is not None:
            self.to(device)

    def forward(self, x):
        """Run the network on ``x`` and return the final linear layer's output.

        Args:
            x: Input batch; the first convolution expects 3 channels.
                Assumed to already live on the same device as the model.
        """
        return self.layer(x)
